{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## install dependencies and import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: torchvision==0.17 in /opt/conda/lib/python3.10/site-packages (0.17.0)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (1.23.4)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.28.1)\n",
      "Requirement already satisfied: torch==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (9.3.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision==0.17) (12.3.101)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2022.6.15)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2.0->torchvision==0.17) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2.0->torchvision==0.17) (1.3.0)\n",
      "Requirement already satisfied: matplotlib==3.5.2 in /opt/conda/lib/python3.10/site-packages (3.5.2)\n",
      "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (4.46.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (1.4.5)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (1.23.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (21.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (9.3.0)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (2.8.2)\n",
      "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib==3.5.2) (1.16.0)\n"
     ]
    }
   ],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel;\n",
    "# a single invocation lets pip resolve all pinned versions together.\n",
    "%pip install torch==2.2 torchvision==0.17 matplotlib==3.5.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define model architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Small two-convolution CNN producing log-probabilities over 10 classes.\n",
    "\n",
    "    Expects MNIST-sized inputs of shape (batch, 1, 28, 28): the two unpadded 3x3\n",
    "    convolutions shrink 28 -> 26 -> 24 and 2x2 max-pooling halves that to 12, so the\n",
    "    flattened feature vector has 12 * 12 * 32 = 4608 entries (the fc1 input size).\n",
    "    forward() returns a (batch, 10) tensor of log-probabilities for F.nll_loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.cn1 = nn.Conv2d(1, 16, 3, 1)\n",
    "        self.cn2 = nn.Conv2d(16, 32, 3, 1)\n",
    "        # channel-wise dropout on the 4-D (batch, 32, 12, 12) pooled feature maps\n",
    "        self.dp1 = nn.Dropout2d(0.10)\n",
    "        # dp2 sees the 2-D (batch, 64) output of fc1, where Dropout2d is deprecated\n",
    "        # (slated to become an error); plain Dropout keeps the same element-wise behavior.\n",
    "        self.dp2 = nn.Dropout(0.25)\n",
    "        self.fc1 = nn.Linear(4608, 64)  # 4608 = 12 x 12 x 32\n",
    "        self.fc2 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.cn1(x))    # (batch, 16, 26, 26)\n",
    "        x = F.relu(self.cn2(x))    # (batch, 32, 24, 24)\n",
    "        x = F.max_pool2d(x, 2)     # (batch, 32, 12, 12)\n",
    "        x = self.dp1(x)\n",
    "        x = torch.flatten(x, 1)    # (batch, 4608)\n",
    "        x = F.relu(self.fc1(x))    # (batch, 64)\n",
    "        x = self.dp2(x)\n",
    "        x = self.fc2(x)            # (batch, 10) raw class scores\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define training and inference routines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_dataloader, optim, epoch, log_interval=10):\n",
    "    \"\"\"Run one training epoch, printing the batch loss every `log_interval` batches.\n",
    "\n",
    "    Args:\n",
    "        model: network returning log-probabilities (fed to F.nll_loss).\n",
    "        device: torch.device each batch is moved to; the model is not moved here,\n",
    "            so it must already live on this device.\n",
    "        train_dataloader: DataLoader yielding (inputs, targets) batches; its\n",
    "            .dataset and len() are used for the progress message.\n",
    "        optim: optimizer over the model's parameters (this name shadows the\n",
    "            `torch.optim` module inside this function).\n",
    "        epoch: epoch number, used only in the progress messages.\n",
    "        log_interval: print progress every this many batches (default 10).\n",
    "    \"\"\"\n",
    "    model.train()  # training mode (e.g. dropout layers are active)\n",
    "    n_seen = 0  # samples processed before the current batch\n",
    "    for b_i, (X, y) in enumerate(train_dataloader):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        optim.zero_grad()\n",
    "        pred_prob = model(X)\n",
    "        loss = F.nll_loss(pred_prob, y)  # negative log-likelihood loss\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        if b_i % log_interval == 0:\n",
    "            # Report the running sample count rather than b_i * len(X), which\n",
    "            # miscounts whenever the current batch is smaller (e.g. the last one).\n",
    "            print('epoch: {} [{}/{} ({:.0f}%)]\\t training loss: {:.6f}'.format(\n",
    "                epoch, n_seen, len(train_dataloader.dataset),\n",
    "                100. * b_i / len(train_dataloader), loss.item()))\n",
    "        n_seen += len(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_dataloader):\n",
    "    \"\"\"Evaluate `model` on `test_dataloader` and print mean loss and accuracy.\n",
    "\n",
    "    Per-sample losses are summed over all batches and divided by the dataset size,\n",
    "    so the printed loss is a true per-sample mean even if the last batch is smaller.\n",
    "    Returns None; results are only printed.\n",
    "    \"\"\"\n",
    "    model.eval()  # evaluation mode (e.g. dropout layers are disabled)\n",
    "    total_loss = 0\n",
    "    n_correct = 0\n",
    "    with torch.no_grad():  # no gradients are needed for evaluation\n",
    "        for X, y in test_dataloader:\n",
    "            X = X.to(device)\n",
    "            y = y.to(device)\n",
    "            log_probs = model(X)\n",
    "            # reduction='sum' so that batches of different sizes are weighted correctly\n",
    "            total_loss += F.nll_loss(log_probs, y, reduction='sum').item()\n",
    "            # use argmax to pick the most likely class for each sample\n",
    "            predicted = log_probs.argmax(dim=1, keepdim=True)\n",
    "            n_correct += predicted.eq(y.view_as(predicted)).sum().item()\n",
    "\n",
    "    n_samples = len(test_dataloader.dataset)\n",
    "    mean_loss = total_loss / n_samples\n",
    "\n",
    "    print('\\nTest dataset: Overall Loss: {:.4f}, Overall Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        mean_loss, n_correct, n_samples,\n",
    "        100. * n_correct / n_samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## create data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST normalization constants. ToTensor() scales pixels to [0, 1] by dividing by 255,\n",
    "# but these values are the raw-pixel mean/std divided by 256, so they are marginally\n",
    "# off (dividing by 255 gives ~0.1307 / ~0.3081). The difference is negligible; they are\n",
    "# kept unchanged so results stay comparable with earlier runs.\n",
    "MNIST_MEAN = (0.1302,)\n",
    "MNIST_STD = (0.3069,)\n",
    "\n",
    "# Train and test images must be normalized identically, so both datasets share one transform.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(MNIST_MEAN, MNIST_STD),\n",
    "])\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),\n",
    "    batch_size=32, shuffle=True)\n",
    "\n",
    "test_dataloader = DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=mnist_transform),\n",
    "    batch_size=500, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## instantiate model and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)  # seed the global RNG before building the model so weight init is reproducible\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "# train()/test() move every batch to `device`, so the model's parameters must live there too;\n",
    "# .to(device) is a no-op on CPU but prevents a device mismatch if `device` is changed to a GPU.\n",
    "model = ConvNet().to(device)\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:1347: UserWarning: dropout2d: Received a 2-D input to dropout2d, which is deprecated and will result in an error in a future release. To retain the behavior and silence this warning, please use dropout instead. Note that dropout2d exists to provide channel-wise dropout on inputs with 2 spatial dimensions, a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\n",
      "  warnings.warn(warn_msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 [0/60000 (0%)]\t training loss: 2.310609\n",
      "epoch: 1 [320/60000 (1%)]\t training loss: 1.924133\n",
      "epoch: 1 [640/60000 (1%)]\t training loss: 1.313337\n",
      "epoch: 1 [960/60000 (2%)]\t training loss: 0.796470\n",
      "epoch: 1 [1280/60000 (2%)]\t training loss: 0.819801\n",
      "epoch: 1 [1600/60000 (3%)]\t training loss: 0.678459\n",
      "epoch: 1 [1920/60000 (3%)]\t training loss: 0.477066\n",
      "epoch: 1 [2240/60000 (4%)]\t training loss: 0.527125\n",
      "epoch: 1 [2560/60000 (4%)]\t training loss: 0.469366\n",
      "epoch: 1 [2880/60000 (5%)]\t training loss: 0.239070\n",
      "epoch: 1 [3200/60000 (5%)]\t training loss: 0.523908\n",
      "epoch: 1 [3520/60000 (6%)]\t training loss: 0.267831\n",
      "epoch: 1 [3840/60000 (6%)]\t training loss: 0.464627\n",
      "epoch: 1 [4160/60000 (7%)]\t training loss: 0.413977\n",
      "epoch: 1 [4480/60000 (7%)]\t training loss: 0.324858\n",
      "epoch: 1 [4800/60000 (8%)]\t training loss: 0.495965\n",
      "epoch: 1 [5120/60000 (9%)]\t training loss: 0.153463\n",
      "epoch: 1 [5440/60000 (9%)]\t training loss: 0.377620\n",
      "epoch: 1 [5760/60000 (10%)]\t training loss: 0.097269\n",
      "epoch: 1 [6080/60000 (10%)]\t training loss: 0.181864\n",
      "epoch: 1 [6400/60000 (11%)]\t training loss: 0.295800\n",
      "epoch: 1 [6720/60000 (11%)]\t training loss: 0.082916\n",
      "epoch: 1 [7040/60000 (12%)]\t training loss: 0.353389\n",
      "epoch: 1 [7360/60000 (12%)]\t training loss: 0.040657\n",
      "epoch: 1 [7680/60000 (13%)]\t training loss: 0.365705\n",
      "epoch: 1 [8000/60000 (13%)]\t training loss: 0.231380\n",
      "epoch: 1 [8320/60000 (14%)]\t training loss: 0.321043\n",
      "epoch: 1 [8640/60000 (14%)]\t training loss: 0.211378\n",
      "epoch: 1 [8960/60000 (15%)]\t training loss: 0.054053\n",
      "epoch: 1 [9280/60000 (15%)]\t training loss: 0.337620\n",
      "epoch: 1 [9600/60000 (16%)]\t training loss: 0.216428\n",
      "epoch: 1 [9920/60000 (17%)]\t training loss: 0.174468\n",
      "epoch: 1 [10240/60000 (17%)]\t training loss: 0.213010\n",
      "epoch: 1 [10560/60000 (18%)]\t training loss: 0.357055\n",
      "epoch: 1 [10880/60000 (18%)]\t training loss: 0.070653\n",
      "epoch: 1 [11200/60000 (19%)]\t training loss: 0.055425\n",
      "epoch: 1 [11520/60000 (19%)]\t training loss: 0.230391\n",
      "epoch: 1 [11840/60000 (20%)]\t training loss: 0.048202\n",
      "epoch: 1 [12160/60000 (20%)]\t training loss: 0.403469\n",
      "epoch: 1 [12480/60000 (21%)]\t training loss: 0.025848\n",
      "epoch: 1 [12800/60000 (21%)]\t training loss: 0.104852\n",
      "epoch: 1 [13120/60000 (22%)]\t training loss: 0.101376\n",
      "epoch: 1 [13440/60000 (22%)]\t training loss: 0.178021\n",
      "epoch: 1 [13760/60000 (23%)]\t training loss: 0.070927\n",
      "epoch: 1 [14080/60000 (23%)]\t training loss: 0.048668\n",
      "epoch: 1 [14400/60000 (24%)]\t training loss: 0.090581\n",
      "epoch: 1 [14720/60000 (25%)]\t training loss: 0.165846\n",
      "epoch: 1 [15040/60000 (25%)]\t training loss: 0.324705\n",
      "epoch: 1 [15360/60000 (26%)]\t training loss: 0.077008\n",
      "epoch: 1 [15680/60000 (26%)]\t training loss: 0.128724\n",
      "epoch: 1 [16000/60000 (27%)]\t training loss: 0.084200\n",
      "epoch: 1 [16320/60000 (27%)]\t training loss: 0.158558\n",
      "epoch: 1 [16640/60000 (28%)]\t training loss: 0.040566\n",
      "epoch: 1 [16960/60000 (28%)]\t training loss: 0.287925\n",
      "epoch: 1 [17280/60000 (29%)]\t training loss: 0.057604\n",
      "epoch: 1 [17600/60000 (29%)]\t training loss: 0.031712\n",
      "epoch: 1 [17920/60000 (30%)]\t training loss: 0.046777\n",
      "epoch: 1 [18240/60000 (30%)]\t training loss: 0.025740\n",
      "epoch: 1 [18560/60000 (31%)]\t training loss: 0.065141\n",
      "epoch: 1 [18880/60000 (31%)]\t training loss: 0.220227\n",
      "epoch: 1 [19200/60000 (32%)]\t training loss: 0.102033\n",
      "epoch: 1 [19520/60000 (33%)]\t training loss: 0.068922\n",
      "epoch: 1 [19840/60000 (33%)]\t training loss: 0.075009\n",
      "epoch: 1 [20160/60000 (34%)]\t training loss: 0.111961\n",
      "epoch: 1 [20480/60000 (34%)]\t training loss: 0.056665\n",
      "epoch: 1 [20800/60000 (35%)]\t training loss: 0.234118\n",
      "epoch: 1 [21120/60000 (35%)]\t training loss: 0.303591\n",
      "epoch: 1 [21440/60000 (36%)]\t training loss: 0.217894\n",
      "epoch: 1 [21760/60000 (36%)]\t training loss: 0.091958\n",
      "epoch: 1 [22080/60000 (37%)]\t training loss: 0.079766\n",
      "epoch: 1 [22400/60000 (37%)]\t training loss: 0.151057\n",
      "epoch: 1 [22720/60000 (38%)]\t training loss: 0.262819\n",
      "epoch: 1 [23040/60000 (38%)]\t training loss: 0.061579\n",
      "epoch: 1 [23360/60000 (39%)]\t training loss: 0.135419\n",
      "epoch: 1 [23680/60000 (39%)]\t training loss: 0.285402\n",
      "epoch: 1 [24000/60000 (40%)]\t training loss: 0.148586\n",
      "epoch: 1 [24320/60000 (41%)]\t training loss: 0.066807\n",
      "epoch: 1 [24640/60000 (41%)]\t training loss: 0.126747\n",
      "epoch: 1 [24960/60000 (42%)]\t training loss: 0.026213\n",
      "epoch: 1 [25280/60000 (42%)]\t training loss: 0.024073\n",
      "epoch: 1 [25600/60000 (43%)]\t training loss: 0.600073\n",
      "epoch: 1 [25920/60000 (43%)]\t training loss: 0.497171\n",
      "epoch: 1 [26240/60000 (44%)]\t training loss: 0.118145\n",
      "epoch: 1 [26560/60000 (44%)]\t training loss: 0.040906\n",
      "epoch: 1 [26880/60000 (45%)]\t training loss: 0.154461\n",
      "epoch: 1 [27200/60000 (45%)]\t training loss: 0.013285\n",
      "epoch: 1 [27520/60000 (46%)]\t training loss: 0.027993\n",
      "epoch: 1 [27840/60000 (46%)]\t training loss: 0.127872\n",
      "epoch: 1 [28160/60000 (47%)]\t training loss: 0.040515\n",
      "epoch: 1 [28480/60000 (47%)]\t training loss: 0.402136\n",
      "epoch: 1 [28800/60000 (48%)]\t training loss: 0.138493\n",
      "epoch: 1 [29120/60000 (49%)]\t training loss: 0.207102\n",
      "epoch: 1 [29440/60000 (49%)]\t training loss: 0.012119\n",
      "epoch: 1 [29760/60000 (50%)]\t training loss: 0.054411\n",
      "epoch: 1 [30080/60000 (50%)]\t training loss: 0.546068\n",
      "epoch: 1 [30400/60000 (51%)]\t training loss: 0.144434\n",
      "epoch: 1 [30720/60000 (51%)]\t training loss: 0.225871\n",
      "epoch: 1 [31040/60000 (52%)]\t training loss: 0.069757\n",
      "epoch: 1 [31360/60000 (52%)]\t training loss: 0.068556\n",
      "epoch: 1 [31680/60000 (53%)]\t training loss: 0.034509\n",
      "epoch: 1 [32000/60000 (53%)]\t training loss: 0.015917\n",
      "epoch: 1 [32320/60000 (54%)]\t training loss: 0.107069\n",
      "epoch: 1 [32640/60000 (54%)]\t training loss: 0.110502\n",
      "epoch: 1 [32960/60000 (55%)]\t training loss: 0.077687\n",
      "epoch: 1 [33280/60000 (55%)]\t training loss: 0.276489\n",
      "epoch: 1 [33600/60000 (56%)]\t training loss: 0.186835\n",
      "epoch: 1 [33920/60000 (57%)]\t training loss: 0.017603\n",
      "epoch: 1 [34240/60000 (57%)]\t training loss: 0.025609\n",
      "epoch: 1 [34560/60000 (58%)]\t training loss: 0.003528\n",
      "epoch: 1 [34880/60000 (58%)]\t training loss: 0.123900\n",
      "epoch: 1 [35200/60000 (59%)]\t training loss: 0.229518\n",
      "epoch: 1 [35520/60000 (59%)]\t training loss: 0.131734\n",
      "epoch: 1 [35840/60000 (60%)]\t training loss: 0.015819\n",
      "epoch: 1 [36160/60000 (60%)]\t training loss: 0.165748\n",
      "epoch: 1 [36480/60000 (61%)]\t training loss: 0.117925\n",
      "epoch: 1 [36800/60000 (61%)]\t training loss: 0.140046\n",
      "epoch: 1 [37120/60000 (62%)]\t training loss: 0.043650\n",
      "epoch: 1 [37440/60000 (62%)]\t training loss: 0.216800\n",
      "epoch: 1 [37760/60000 (63%)]\t training loss: 0.032769\n",
      "epoch: 1 [38080/60000 (63%)]\t training loss: 0.113341\n",
      "epoch: 1 [38400/60000 (64%)]\t training loss: 0.074898\n",
      "epoch: 1 [38720/60000 (65%)]\t training loss: 0.488738\n",
      "epoch: 1 [39040/60000 (65%)]\t training loss: 0.048389\n",
      "epoch: 1 [39360/60000 (66%)]\t training loss: 0.176396\n",
      "epoch: 1 [39680/60000 (66%)]\t training loss: 0.021329\n",
      "epoch: 1 [40000/60000 (67%)]\t training loss: 0.106253\n",
      "epoch: 1 [40320/60000 (67%)]\t training loss: 0.065340\n",
      "epoch: 1 [40640/60000 (68%)]\t training loss: 0.202225\n",
      "epoch: 1 [40960/60000 (68%)]\t training loss: 0.194045\n",
      "epoch: 1 [41280/60000 (69%)]\t training loss: 0.213245\n",
      "epoch: 1 [41600/60000 (69%)]\t training loss: 0.155805\n",
      "epoch: 1 [41920/60000 (70%)]\t training loss: 0.040492\n",
      "epoch: 1 [42240/60000 (70%)]\t training loss: 0.011083\n",
      "epoch: 1 [42560/60000 (71%)]\t training loss: 0.037179\n",
      "epoch: 1 [42880/60000 (71%)]\t training loss: 0.100338\n",
      "epoch: 1 [43200/60000 (72%)]\t training loss: 0.045283\n",
      "epoch: 1 [43520/60000 (73%)]\t training loss: 0.226100\n",
      "epoch: 1 [43840/60000 (73%)]\t training loss: 0.040024\n",
      "epoch: 1 [44160/60000 (74%)]\t training loss: 0.116498\n",
      "epoch: 1 [44480/60000 (74%)]\t training loss: 0.061347\n",
      "epoch: 1 [44800/60000 (75%)]\t training loss: 0.250636\n",
      "epoch: 1 [45120/60000 (75%)]\t training loss: 0.015089\n",
      "epoch: 1 [45440/60000 (76%)]\t training loss: 0.179538\n",
      "epoch: 1 [45760/60000 (76%)]\t training loss: 0.005873\n",
      "epoch: 1 [46080/60000 (77%)]\t training loss: 0.246703\n",
      "epoch: 1 [46400/60000 (77%)]\t training loss: 0.059748\n",
      "epoch: 1 [46720/60000 (78%)]\t training loss: 0.086515\n",
      "epoch: 1 [47040/60000 (78%)]\t training loss: 0.056783\n",
      "epoch: 1 [47360/60000 (79%)]\t training loss: 0.026905\n",
      "epoch: 1 [47680/60000 (79%)]\t training loss: 0.229662\n",
      "epoch: 1 [48000/60000 (80%)]\t training loss: 0.219784\n",
      "epoch: 1 [48320/60000 (81%)]\t training loss: 0.173513\n",
      "epoch: 1 [48640/60000 (81%)]\t training loss: 0.020131\n",
      "epoch: 1 [48960/60000 (82%)]\t training loss: 0.092622\n",
      "epoch: 1 [49280/60000 (82%)]\t training loss: 0.143768\n",
      "epoch: 1 [49600/60000 (83%)]\t training loss: 0.227249\n",
      "epoch: 1 [49920/60000 (83%)]\t training loss: 0.129910\n",
      "epoch: 1 [50240/60000 (84%)]\t training loss: 0.071571\n",
      "epoch: 1 [50560/60000 (84%)]\t training loss: 0.003748\n",
      "epoch: 1 [50880/60000 (85%)]\t training loss: 0.004240\n",
      "epoch: 1 [51200/60000 (85%)]\t training loss: 0.232011\n",
      "epoch: 1 [51520/60000 (86%)]\t training loss: 0.028602\n",
      "epoch: 1 [51840/60000 (86%)]\t training loss: 0.013641\n",
      "epoch: 1 [52160/60000 (87%)]\t training loss: 0.010407\n",
      "epoch: 1 [52480/60000 (87%)]\t training loss: 0.014092\n",
      "epoch: 1 [52800/60000 (88%)]\t training loss: 0.041886\n",
      "epoch: 1 [53120/60000 (89%)]\t training loss: 0.029681\n",
      "epoch: 1 [53440/60000 (89%)]\t training loss: 0.032927\n",
      "epoch: 1 [53760/60000 (90%)]\t training loss: 0.044958\n",
      "epoch: 1 [54080/60000 (90%)]\t training loss: 0.014901\n",
      "epoch: 1 [54400/60000 (91%)]\t training loss: 0.136652\n",
      "epoch: 1 [54720/60000 (91%)]\t training loss: 0.011173\n",
      "epoch: 1 [55040/60000 (92%)]\t training loss: 0.002030\n",
      "epoch: 1 [55360/60000 (92%)]\t training loss: 0.066842\n",
      "epoch: 1 [55680/60000 (93%)]\t training loss: 0.048196\n",
      "epoch: 1 [56000/60000 (93%)]\t training loss: 0.052259\n",
      "epoch: 1 [56320/60000 (94%)]\t training loss: 0.161982\n",
      "epoch: 1 [56640/60000 (94%)]\t training loss: 0.054209\n",
      "epoch: 1 [56960/60000 (95%)]\t training loss: 0.028114\n",
      "epoch: 1 [57280/60000 (95%)]\t training loss: 0.107114\n",
      "epoch: 1 [57600/60000 (96%)]\t training loss: 0.003919\n",
      "epoch: 1 [57920/60000 (97%)]\t training loss: 0.093449\n",
      "epoch: 1 [58240/60000 (97%)]\t training loss: 0.043330\n",
      "epoch: 1 [58560/60000 (98%)]\t training loss: 0.005689\n",
      "epoch: 1 [58880/60000 (98%)]\t training loss: 0.087195\n",
      "epoch: 1 [59200/60000 (99%)]\t training loss: 0.009148\n",
      "epoch: 1 [59520/60000 (99%)]\t training loss: 0.035150\n",
      "epoch: 1 [59840/60000 (100%)]\t training loss: 0.024602\n",
      "\n",
      "Test dataset: Overall Loss: 0.0478, Overall Accuracy: 9845/10000 (98%)\n",
      "\n",
      "epoch: 2 [0/60000 (0%)]\t training loss: 0.075755\n",
      "epoch: 2 [320/60000 (1%)]\t training loss: 0.052196\n",
      "epoch: 2 [640/60000 (1%)]\t training loss: 0.176723\n",
      "epoch: 2 [960/60000 (2%)]\t training loss: 0.084691\n",
      "epoch: 2 [1280/60000 (2%)]\t training loss: 0.037865\n",
      "epoch: 2 [1600/60000 (3%)]\t training loss: 0.002995\n",
      "epoch: 2 [1920/60000 (3%)]\t training loss: 0.085839\n",
      "epoch: 2 [2240/60000 (4%)]\t training loss: 0.028786\n",
      "epoch: 2 [2560/60000 (4%)]\t training loss: 0.134300\n",
      "epoch: 2 [2880/60000 (5%)]\t training loss: 0.009706\n",
      "epoch: 2 [3200/60000 (5%)]\t training loss: 0.072933\n",
      "epoch: 2 [3520/60000 (6%)]\t training loss: 0.010232\n",
      "epoch: 2 [3840/60000 (6%)]\t training loss: 0.069974\n",
      "epoch: 2 [4160/60000 (7%)]\t training loss: 0.010989\n",
      "epoch: 2 [4480/60000 (7%)]\t training loss: 0.037965\n",
      "epoch: 2 [4800/60000 (8%)]\t training loss: 0.048734\n",
      "epoch: 2 [5120/60000 (9%)]\t training loss: 0.174570\n",
      "epoch: 2 [5440/60000 (9%)]\t training loss: 0.169110\n",
      "epoch: 2 [5760/60000 (10%)]\t training loss: 0.078296\n",
      "epoch: 2 [6080/60000 (10%)]\t training loss: 0.047879\n",
      "epoch: 2 [6400/60000 (11%)]\t training loss: 0.039148\n",
      "epoch: 2 [6720/60000 (11%)]\t training loss: 0.041429\n",
      "epoch: 2 [7040/60000 (12%)]\t training loss: 0.012437\n",
      "epoch: 2 [7360/60000 (12%)]\t training loss: 0.093218\n",
      "epoch: 2 [7680/60000 (13%)]\t training loss: 0.025234\n",
      "epoch: 2 [8000/60000 (13%)]\t training loss: 0.192158\n",
      "epoch: 2 [8320/60000 (14%)]\t training loss: 0.038945\n",
      "epoch: 2 [8640/60000 (14%)]\t training loss: 0.018283\n",
      "epoch: 2 [8960/60000 (15%)]\t training loss: 0.026005\n",
      "epoch: 2 [9280/60000 (15%)]\t training loss: 0.002971\n",
      "epoch: 2 [9600/60000 (16%)]\t training loss: 0.055324\n",
      "epoch: 2 [9920/60000 (17%)]\t training loss: 0.003880\n",
      "epoch: 2 [10240/60000 (17%)]\t training loss: 0.005629\n",
      "epoch: 2 [10560/60000 (18%)]\t training loss: 0.001743\n",
      "epoch: 2 [10880/60000 (18%)]\t training loss: 0.063067\n",
      "epoch: 2 [11200/60000 (19%)]\t training loss: 0.034985\n",
      "epoch: 2 [11520/60000 (19%)]\t training loss: 0.007477\n",
      "epoch: 2 [11840/60000 (20%)]\t training loss: 0.091652\n",
      "epoch: 2 [12160/60000 (20%)]\t training loss: 0.057981\n",
      "epoch: 2 [12480/60000 (21%)]\t training loss: 0.045031\n",
      "epoch: 2 [12800/60000 (21%)]\t training loss: 0.006746\n",
      "epoch: 2 [13120/60000 (22%)]\t training loss: 0.151338\n",
      "epoch: 2 [13440/60000 (22%)]\t training loss: 0.051468\n",
      "epoch: 2 [13760/60000 (23%)]\t training loss: 0.118806\n",
      "epoch: 2 [14080/60000 (23%)]\t training loss: 0.059753\n",
      "epoch: 2 [14400/60000 (24%)]\t training loss: 0.023917\n",
      "epoch: 2 [14720/60000 (25%)]\t training loss: 0.010549\n",
      "epoch: 2 [15040/60000 (25%)]\t training loss: 0.001766\n",
      "epoch: 2 [15360/60000 (26%)]\t training loss: 0.083428\n",
      "epoch: 2 [15680/60000 (26%)]\t training loss: 0.060596\n",
      "epoch: 2 [16000/60000 (27%)]\t training loss: 0.054066\n",
      "epoch: 2 [16320/60000 (27%)]\t training loss: 0.155204\n",
      "epoch: 2 [16640/60000 (28%)]\t training loss: 0.007297\n",
      "epoch: 2 [16960/60000 (28%)]\t training loss: 0.391733\n",
      "epoch: 2 [17280/60000 (29%)]\t training loss: 0.041596\n",
      "epoch: 2 [17600/60000 (29%)]\t training loss: 0.058965\n",
      "epoch: 2 [17920/60000 (30%)]\t training loss: 0.140119\n",
      "epoch: 2 [18240/60000 (30%)]\t training loss: 0.083102\n",
      "epoch: 2 [18560/60000 (31%)]\t training loss: 0.001707\n",
      "epoch: 2 [18880/60000 (31%)]\t training loss: 0.050675\n",
      "epoch: 2 [19200/60000 (32%)]\t training loss: 0.163005\n",
      "epoch: 2 [19520/60000 (33%)]\t training loss: 0.006558\n",
      "epoch: 2 [19840/60000 (33%)]\t training loss: 0.071730\n",
      "epoch: 2 [20160/60000 (34%)]\t training loss: 0.021590\n",
      "epoch: 2 [20480/60000 (34%)]\t training loss: 0.116760\n",
      "epoch: 2 [20800/60000 (35%)]\t training loss: 0.007973\n",
      "epoch: 2 [21120/60000 (35%)]\t training loss: 0.081753\n",
      "epoch: 2 [21440/60000 (36%)]\t training loss: 0.031363\n",
      "epoch: 2 [21760/60000 (36%)]\t training loss: 0.012533\n",
      "epoch: 2 [22080/60000 (37%)]\t training loss: 0.007146\n",
      "epoch: 2 [22400/60000 (37%)]\t training loss: 0.002010\n",
      "epoch: 2 [22720/60000 (38%)]\t training loss: 0.007648\n",
      "epoch: 2 [23040/60000 (38%)]\t training loss: 0.105152\n",
      "epoch: 2 [23360/60000 (39%)]\t training loss: 0.101917\n",
      "epoch: 2 [23680/60000 (39%)]\t training loss: 0.163842\n",
      "epoch: 2 [24000/60000 (40%)]\t training loss: 0.034108\n",
      "epoch: 2 [24320/60000 (41%)]\t training loss: 0.024196\n",
      "epoch: 2 [24640/60000 (41%)]\t training loss: 0.089241\n",
      "epoch: 2 [24960/60000 (42%)]\t training loss: 0.051603\n",
      "epoch: 2 [25280/60000 (42%)]\t training loss: 0.054226\n",
      "epoch: 2 [25600/60000 (43%)]\t training loss: 0.058934\n",
      "epoch: 2 [25920/60000 (43%)]\t training loss: 0.028085\n",
      "epoch: 2 [26240/60000 (44%)]\t training loss: 0.003685\n",
      "epoch: 2 [26560/60000 (44%)]\t training loss: 0.005134\n",
      "epoch: 2 [26880/60000 (45%)]\t training loss: 0.069611\n",
      "epoch: 2 [27200/60000 (45%)]\t training loss: 0.453698\n",
      "epoch: 2 [27520/60000 (46%)]\t training loss: 0.077049\n",
      "epoch: 2 [27840/60000 (46%)]\t training loss: 0.003867\n",
      "epoch: 2 [28160/60000 (47%)]\t training loss: 0.049487\n",
      "epoch: 2 [28480/60000 (47%)]\t training loss: 0.097699\n",
      "epoch: 2 [28800/60000 (48%)]\t training loss: 0.013503\n",
      "epoch: 2 [29120/60000 (49%)]\t training loss: 0.003477\n",
      "epoch: 2 [29440/60000 (49%)]\t training loss: 0.046258\n",
      "epoch: 2 [29760/60000 (50%)]\t training loss: 0.057281\n",
      "epoch: 2 [30080/60000 (50%)]\t training loss: 0.017527\n",
      "epoch: 2 [30400/60000 (51%)]\t training loss: 0.025550\n",
      "epoch: 2 [30720/60000 (51%)]\t training loss: 0.027727\n",
      "epoch: 2 [31040/60000 (52%)]\t training loss: 0.170225\n",
      "epoch: 2 [31360/60000 (52%)]\t training loss: 0.108750\n",
      "epoch: 2 [31680/60000 (53%)]\t training loss: 0.011146\n",
      "epoch: 2 [32000/60000 (53%)]\t training loss: 0.003069\n",
      "epoch: 2 [32320/60000 (54%)]\t training loss: 0.006169\n",
      "epoch: 2 [32640/60000 (54%)]\t training loss: 0.024151\n",
      "epoch: 2 [32960/60000 (55%)]\t training loss: 0.001360\n",
      "epoch: 2 [33280/60000 (55%)]\t training loss: 0.147938\n",
      "epoch: 2 [33600/60000 (56%)]\t training loss: 0.015906\n",
      "epoch: 2 [33920/60000 (57%)]\t training loss: 0.013825\n",
      "epoch: 2 [34240/60000 (57%)]\t training loss: 0.330136\n",
      "epoch: 2 [34560/60000 (58%)]\t training loss: 0.005314\n",
      "epoch: 2 [34880/60000 (58%)]\t training loss: 0.150999\n",
      "epoch: 2 [35200/60000 (59%)]\t training loss: 0.025108\n",
      "epoch: 2 [35520/60000 (59%)]\t training loss: 0.020267\n",
      "epoch: 2 [35840/60000 (60%)]\t training loss: 0.003014\n",
      "epoch: 2 [36160/60000 (60%)]\t training loss: 0.040846\n",
      "epoch: 2 [36480/60000 (61%)]\t training loss: 0.092544\n",
      "epoch: 2 [36800/60000 (61%)]\t training loss: 0.002039\n",
      "epoch: 2 [37120/60000 (62%)]\t training loss: 0.433007\n",
      "epoch: 2 [37440/60000 (62%)]\t training loss: 0.005830\n",
      "epoch: 2 [37760/60000 (63%)]\t training loss: 0.054115\n",
      "epoch: 2 [38080/60000 (63%)]\t training loss: 0.018357\n",
      "epoch: 2 [38400/60000 (64%)]\t training loss: 0.050539\n",
      "epoch: 2 [38720/60000 (65%)]\t training loss: 0.087123\n",
      "epoch: 2 [39040/60000 (65%)]\t training loss: 0.014614\n",
      "epoch: 2 [39360/60000 (66%)]\t training loss: 0.005193\n",
      "epoch: 2 [39680/60000 (66%)]\t training loss: 0.177055\n",
      "epoch: 2 [40000/60000 (67%)]\t training loss: 0.007170\n",
      "epoch: 2 [40320/60000 (67%)]\t training loss: 0.240168\n",
      "epoch: 2 [40640/60000 (68%)]\t training loss: 0.587453\n",
      "epoch: 2 [40960/60000 (68%)]\t training loss: 0.033368\n",
      "epoch: 2 [41280/60000 (69%)]\t training loss: 0.106131\n",
      "epoch: 2 [41600/60000 (69%)]\t training loss: 0.142523\n",
      "epoch: 2 [41920/60000 (70%)]\t training loss: 0.005451\n",
      "epoch: 2 [42240/60000 (70%)]\t training loss: 0.006553\n",
      "epoch: 2 [42560/60000 (71%)]\t training loss: 0.057431\n",
      "epoch: 2 [42880/60000 (71%)]\t training loss: 0.008507\n",
      "epoch: 2 [43200/60000 (72%)]\t training loss: 0.002467\n",
      "epoch: 2 [43520/60000 (73%)]\t training loss: 0.038520\n",
      "epoch: 2 [43840/60000 (73%)]\t training loss: 0.240774\n",
      "epoch: 2 [44160/60000 (74%)]\t training loss: 0.177386\n",
      "epoch: 2 [44480/60000 (74%)]\t training loss: 0.010960\n",
      "epoch: 2 [44800/60000 (75%)]\t training loss: 0.004259\n",
      "epoch: 2 [45120/60000 (75%)]\t training loss: 0.069959\n",
      "epoch: 2 [45440/60000 (76%)]\t training loss: 0.044531\n",
      "epoch: 2 [45760/60000 (76%)]\t training loss: 0.046592\n",
      "epoch: 2 [46080/60000 (77%)]\t training loss: 0.084542\n",
      "epoch: 2 [46400/60000 (77%)]\t training loss: 0.021363\n",
      "epoch: 2 [46720/60000 (78%)]\t training loss: 0.160489\n",
      "epoch: 2 [47040/60000 (78%)]\t training loss: 0.038613\n",
      "epoch: 2 [47360/60000 (79%)]\t training loss: 0.050239\n",
      "epoch: 2 [47680/60000 (79%)]\t training loss: 0.027921\n",
      "epoch: 2 [48000/60000 (80%)]\t training loss: 0.073250\n",
      "epoch: 2 [48320/60000 (81%)]\t training loss: 0.183588\n",
      "epoch: 2 [48640/60000 (81%)]\t training loss: 0.045946\n",
      "epoch: 2 [48960/60000 (82%)]\t training loss: 0.006745\n",
      "epoch: 2 [49280/60000 (82%)]\t training loss: 0.003105\n",
      "epoch: 2 [49600/60000 (83%)]\t training loss: 0.010110\n",
      "epoch: 2 [49920/60000 (83%)]\t training loss: 0.116635\n",
      "epoch: 2 [50240/60000 (84%)]\t training loss: 0.039653\n",
      "epoch: 2 [50560/60000 (84%)]\t training loss: 0.005039\n",
      "epoch: 2 [50880/60000 (85%)]\t training loss: 0.006200\n",
      "epoch: 2 [51200/60000 (85%)]\t training loss: 0.013030\n",
      "epoch: 2 [51520/60000 (86%)]\t training loss: 0.213073\n",
      "epoch: 2 [51840/60000 (86%)]\t training loss: 0.005721\n",
      "epoch: 2 [52160/60000 (87%)]\t training loss: 0.003327\n",
      "epoch: 2 [52480/60000 (87%)]\t training loss: 0.117841\n",
      "epoch: 2 [52800/60000 (88%)]\t training loss: 0.050646\n",
      "epoch: 2 [53120/60000 (89%)]\t training loss: 0.005108\n",
      "epoch: 2 [53440/60000 (89%)]\t training loss: 0.083819\n",
      "epoch: 2 [53760/60000 (90%)]\t training loss: 0.174583\n",
      "epoch: 2 [54080/60000 (90%)]\t training loss: 0.003345\n",
      "epoch: 2 [54400/60000 (91%)]\t training loss: 0.393704\n",
      "epoch: 2 [54720/60000 (91%)]\t training loss: 0.115101\n",
      "epoch: 2 [55040/60000 (92%)]\t training loss: 0.018684\n",
      "epoch: 2 [55360/60000 (92%)]\t training loss: 0.072076\n",
      "epoch: 2 [55680/60000 (93%)]\t training loss: 0.012095\n",
      "epoch: 2 [56000/60000 (93%)]\t training loss: 0.027326\n",
      "epoch: 2 [56320/60000 (94%)]\t training loss: 0.001306\n",
      "epoch: 2 [56640/60000 (94%)]\t training loss: 0.058052\n",
      "epoch: 2 [56960/60000 (95%)]\t training loss: 0.012479\n",
      "epoch: 2 [57280/60000 (95%)]\t training loss: 0.077006\n",
      "epoch: 2 [57600/60000 (96%)]\t training loss: 0.074425\n",
      "epoch: 2 [57920/60000 (97%)]\t training loss: 0.150789\n",
      "epoch: 2 [58240/60000 (97%)]\t training loss: 0.051848\n",
      "epoch: 2 [58560/60000 (98%)]\t training loss: 0.003488\n",
      "epoch: 2 [58880/60000 (98%)]\t training loss: 0.003011\n",
      "epoch: 2 [59200/60000 (99%)]\t training loss: 0.025069\n",
      "epoch: 2 [59520/60000 (99%)]\t training loss: 0.021845\n",
      "epoch: 2 [59840/60000 (100%)]\t training loss: 0.008774\n",
      "\n",
      "Test dataset: Overall Loss: 0.0421, Overall Accuracy: 9857/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train for NUM_EPOCHS epochs (numbered from 1), evaluating on the test set after each epoch.\n",
    "NUM_EPOCHS = 2\n",
    "for epoch in range(1, NUM_EPOCHS + 1):\n",
    "    train(model, device, train_dataloader, optimizer, epoch)\n",
    "    test(model, device, test_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference on the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaqElEQVR4nO3df2xV9f3H8VeL9ILaXiylvb2jQEEFwy8ng9rwYygNtC4GtEtA/QMWAoFdzLDzx7qIKFvSjSWOuCD+s8BMxF+JQCRLMym2hNliqDDCph3tugGBFsVxbylSGP18/yDer1cKeMq9ffdeno/kJPTe8+l9ezzhyWlvT9Occ04AAPSxdOsBAAA3JwIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBM3GI9wLd1d3frxIkTyszMVFpamvU4AACPnHPq6OhQMBhUevrVr3P6XYBOnDihgoIC6zEAADfo2LFjGj58+FWf73dfgsvMzLQeAQAQB9f7+zxhAdq4caNGjRqlQYMGqaioSB9//PF3WseX3QAgNVzv7/OEBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUqUS8HAAgGbkEmDZtmguFQtGPL1265ILBoKuqqrru2nA47CSxsbGxsSX5Fg6Hr/n3fdyvgC5cuKDGxkaVlJREH0tPT1dJSYnq6+uv2L+rq0uRSCRmAwCkvrgH6IsvvtClS5eUl5cX83heXp7a2tqu2L+qqkp+vz+68Q44ALg5mL8LrrKyUuFwOLodO3bMeiQAQB+I+88B5eTkaMCAAWpvb495vL29XYFA4Ir9fT6ffD5fvMcAAPRzcb8CysjI0JQpU1RTUxN9rLu7WzU1NSouLo73ywEAklRC7oRQUVGhxYsX6wc/+IGmTZumDRs2qLOzUz/5yU8S8XIAgCSUkAAtXLhQn3/+uV544QW1tbXp3nvvVXV19RVvTAAA3LzSnHPOeohvikQi8vv91mMAAG5QOBxWVlbWVZ83fxccAODmRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATMQ9QC+++KLS0tJitnHjxsX7ZQAASe6WRHzS8ePHa9euXf//Irck5GUAAEksIWW45ZZbFAgEEvGpAQApIiHfAzpy5IiCwaBGjx6tJ554QkePHr3qvl1dXYpEIjEbACD1xT1ARUVF2rJli6qrq7Vp0ya1trZq5syZ6ujo6HH/qqoq+f3+6FZQUBDvkQAA/VCac84l8gXOnDmjkSNH6uWXX9bSpUuveL6rq0tdXV3RjyORCBECgBQQDoeVlZV11ecT/u6AIUOG6O6771Zzc3OPz/t8Pvl8vkSPAQDoZxL+c0Bnz55VS0uL8vPzE/1SAIAkEvcAPf3006qrq9O///1vffTRR3rkkUc0YMAAPfbYY/F+KQBAEov7l+COHz+uxx57TKdPn9awYcM0Y8YMNTQ0aNiwYfF+KQBAEkv4mxC8ikQi8vv91mMAAG7Q9d6EwL3gAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf+FdOhbP/7xjz2vWbZsWa9e68SJE57XnD9/3vOaN954w/OatrY2z2skXfUXJwKIP66AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBB
gAAAJggQAMAEAQIAmCBAAAATBAgAYCLNOeesh/imSCQiv99vPUbS+te//uV5zahRo+I/iLGOjo5erfv73/8e50kQb8ePH/e8Zv369b16rf379/dqHS4Lh8PKysq66vNcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJm6xHgDxtWzZMs9rJk2a1KvX+vTTTz2vueeeezyvue+++zyvmT17tuc1knT//fd7XnPs2DHPawoKCjyv6Uv/+9//PK/5/PPPPa/Jz8/3vKY3jh492qt13Iw0sbgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDPSFFNTU9Mna3qrurq6T17njjvu6NW6e++91/OaxsZGz2umTp3qeU1fOn/+vOc1//znPz2v6c0NbbOzsz2vaWlp8bwGiccVEADABAECAJjwHKA9e/bo4YcfVjAYVFpamrZv3x7zvHNOL7zwgvLz8zV48GCVlJToyJEj8ZoXAJAiPAeos7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b968Xn1NGQCQujy/CaGsrExlZWU9Puec04YNG/T8889r/vz5kqTXX39deXl52r59uxYtWnRj0wIAUkZcvwfU2tqqtrY2lZSURB/z+/0qKipSfX19j2u6uroUiURiNgBA6otrgNra2iRJeXl5MY/n5eVFn/u2qqoq+f3+6FZQUBDPkQAA/ZT5u+AqKysVDoej27Fjx6xHAgD0gbgGKBAISJLa29tjHm9vb48+920+n09ZWVkxGwAg9cU1QIWFhQoEAjE/WR+JRLRv3z4VFxfH86UAAEnO87vgzp49q+bm5ujHra2tOnjwoLKzszVixAitXr1av/71r3XXXXepsLBQa9asUTAY1IIFC+I5NwAgyXkO0P79+/XAAw9EP66oqJAkLV68WFu2bNGzzz6rzs5OLV++XGfOnNGMGTNUXV2tQYMGxW9qAEDSS3POOeshvikSicjv91uPAcCj8vJyz2veeecdz2sOHz7sec03/9HsxZdfftmrdbgsHA5f8/v65u+CAwDcnAgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDC869jAJD6cnNzPa959dVXPa9JT/f+b+B169Z5XsNdrfsnroAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBTAFUKhkOc1w4YN87zmv//9r+c1TU1Nntegf+IKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1IgRQ2ffr0Xq37xS9+EedJerZgwQLPaw4fPhz/QWCCKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQ3IwVS2EMPPdSrdQMHDvS8pqamxvOa+vp6z2uQOrgCAgCYIEAAABOeA7Rnzx49/PDDCgaDSktL0/bt22OeX7JkidLS0mK20tLSeM0LAEgRngPU2dmpyZMna+PGjVfdp7S0VCdPnoxub7755g0NCQBIPZ7fhFBWVqaysrJr7uPz+RQIBHo9FAAg9SXke0C1tbXKzc3V2LFjtXLlSp0+ffqq+3Z1dSkSicRsAIDUF/cAlZaW6vXXX1dNTY1++9vfqq6uTmVlZbp06VKP+1dVVcnv90e3goKCeI8EAOiH4v5zQIsWLYr+eeLEiZo0aZLGjBmj2tpazZkz54r9KysrVVFREf04EokQIQC4CST8bdijR49WTk6Ompube3ze5/MpKysrZgMApL6EB+j48eM6ffq08vPzE/1SAIAk4vlLcGfPno25mmltbdXBgweV
nZ2t7OxsvfTSSyovL1cgEFBLS4ueffZZ3XnnnZo3b15cBwcAJDfPAdq/f78eeOCB6Mdff/9m8eLF2rRpkw4dOqQ//elPOnPmjILBoObOnatf/epX8vl88ZsaAJD00pxzznqIb4pEIvL7/dZjAP3O4MGDPa/Zu3dvr15r/Pjxntc8+OCDntd89NFHntcgeYTD4Wt+X597wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBE3H8lN4DEeOaZZzyv+f73v9+r16qurva8hjtbwyuugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFDDwox/9yPOaNWvWeF4TiUQ8r5GkdevW9Wod4AVXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GCtygoUOHel7zyiuveF4zYMAAz2v+/Oc/e14jSQ0NDb1aB3jBFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKbkQLf0JsbflZXV3teU1hY6HlNS0uL5zVr1qzxvAboK1wBAQBMECAAgAlPAaqqqtLUqVOVmZmp3NxcLViwQE1NTTH7nD9/XqFQSEOHDtXtt9+u8vJytbe3x3VoAEDy8xSguro6hUIhNTQ06IMPPtDFixc1d+5cdXZ2Rvd56qmn9P777+vdd99VXV2dTpw4oUcffTTugwMAkpunNyF8+5utW7ZsUW5urhobGzVr1iyFw2H98Y9/1NatW/Xggw9KkjZv3qx77rlHDQ0Nuv/+++M3OQAgqd3Q94DC4bAkKTs7W5LU2NioixcvqqSkJLrPuHHjNGLECNXX1/f4Obq6uhSJRGI2AEDq63WAuru7tXr1ak2fPl0TJkyQJLW1tSkjI0NDhgyJ2TcvL09tbW09fp6qqir5/f7oVlBQ0NuRAABJpNcBCoVCOnz4sN56660bGqCyslLhcDi6HTt27IY+HwAgOfTqB1FXrVqlnTt3as+ePRo+fHj08UAgoAsXLujMmTMxV0Ht7e0KBAI9fi6fzyefz9ebMQAASczTFZBzTqtWrdK2bdu0e/fuK36ae8qUKRo4cKBqamqijzU1Neno0aMqLi6Oz8QAgJTg6QooFApp69at2rFjhzIzM6Pf1/H7/Ro8eLD8fr+WLl2qiooKZWdnKysrS08++aSKi4t5BxwAIIanAG3atEmSNHv27JjHN2/erCVLlkiSfv/73ys9PV3l5eXq6urSvHnz9Oqrr8ZlWABA6khzzjnrIb4pEonI7/dbj4Gb1N133+15zWeffZaASa40f/58z2vef//9BEwCfDfhcFhZWVlXfZ57wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBEr34jKtDfjRw5slfr/vKXv8R5kp4988wzntfs3LkzAZMAdrgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNSpKTly5f3at2IESPiPEnP6urqPK9xziVgEsAOV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRop+b8aMGZ7XPPnkkwmYBEA8cQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqTo92bOnOl5ze23356ASXrW0tLiec3Zs2cTMAmQXLgCAgCYIEAAABOeAlRVVaWpU6cqMzNTubm5WrBggZqammL2mT17ttLS0mK2FStWxHVoAEDy8xSguro6hUIh
NTQ06IMPPtDFixc1d+5cdXZ2xuy3bNkynTx5MrqtX78+rkMDAJKfpzchVFdXx3y8ZcsW5ebmqrGxUbNmzYo+fuuttyoQCMRnQgBASrqh7wGFw2FJUnZ2dszjb7zxhnJycjRhwgRVVlbq3LlzV/0cXV1dikQiMRsAIPX1+m3Y3d3dWr16taZPn64JEyZEH3/88cc1cuRIBYNBHTp0SM8995yampr03nvv9fh5qqqq9NJLL/V2DABAkup1gEKhkA4fPqy9e/fGPL58+fLonydOnKj8/HzNmTNHLS0tGjNmzBWfp7KyUhUVFdGPI5GICgoKejsWACBJ9CpAq1at0s6dO7Vnzx4NHz78mvsWFRVJkpqbm3sMkM/nk8/n680YAIAk5ilAzjk9+eST2rZtm2pra1VYWHjdNQcPHpQk5efn92pAAEBq8hSgUCikrVu3aseOHcrMzFRbW5skye/3a/DgwWppadHWrVv10EMPaejQoTp06JCeeuopzZo1S5MmTUrIfwAAIDl5CtCmTZskXf5h02/avHmzlixZooyMDO3atUsbNmxQZ2enCgoKVF5erueffz5uAwMAUoPnL8FdS0FBgerq6m5oIADAzYG7YQPf8Le//c3zmjlz5nhe8+WXX3peA6QabkYKADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJhIc9e7xXUfi0Qi8vv91mMAAG5QOBxWVlbWVZ/nCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJfhegfnZrOgBAL13v7/N+F6COjg7rEQAAcXC9v8/73d2wu7u7deLECWVmZiotLS3muUgkooKCAh07duyad1hNdRyHyzgOl3EcLuM4XNYfjoNzTh0dHQoGg0pPv/p1zi19ONN3kp6eruHDh19zn6ysrJv6BPsax+EyjsNlHIfLOA6XWR+H7/Jrdfrdl+AAADcHAgQAMJFUAfL5fFq7dq18Pp/1KKY4DpdxHC7jOFzGcbgsmY5Dv3sTAgDg5pBUV0AAgNRBgAAAJggQAMAEAQIAmEiaAG3cuFGjRo3SoEGDVFRUpI8//th6pD734osvKi0tLWYbN26c9VgJt2fPHj388MMKBoNKS0vT9u3bY553zumFF15Qfn6+Bg8erJKSEh05csRm2AS63nFYsmTJFedHaWmpzbAJUlVVpalTpyozM1O5ublasGCBmpqaYvY5f/68QqGQhg4dqttvv13l5eVqb283mjgxvstxmD179hXnw4oVK4wm7llSBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUKevR+tz48eN18uTJ6LZ3717rkRKus7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b948nT9/vo8nTazrHQdJKi0tjTk/3nzzzT6cMPHq6uoUCoXU0NCgDz74QBcvXtTcuXPV2dkZ3eepp57S+++/r3fffVd1dXU6ceKEHn30UcOp4++7HAdJWrZsWcz5sH79eqOJr8IlgWnTprlQKBT9+NKlSy4YDLqqqirDqfre2rVr3eTJk63HMCXJbdu2Lfpxd3e3CwQC7ne/+130sTNnzjifz+fefPNNgwn7xrePg3POLV682M2fP99kHiunTp1yklxdXZ1z7vL/+4EDB7p33303us+nn37qJLn6+nqrMRPu28fBOed++MMfup/97Gd2Q30H/f4K6MKFC2psbFRJSUn0sfT0dJWUlKi+vt5wMhtHjhxRMBjU6NGj9cQTT+jo0aPWI5lqbW1VW1tbzPnh9/tVVFR0U54ftbW1ys3N1dixY7Vy5UqdPn3aeqSECofDkqTs7GxJUmNjoy5evBhzPowbN04jRoxI6fPh28fha2+88YZycnI0YcIEVVZW6ty5cxbjXVW/uxnpt33xxRe6dOmS8vLyYh7Py8vTZ599ZjSVjaKiIm3ZskVj
x47VyZMn9dJLL2nmzJk6fPiwMjMzrccz0dbWJkk9nh9fP3ezKC0t1aOPPqrCwkK1tLTol7/8pcrKylRfX68BAwZYjxd33d3dWr16taZPn64JEyZIunw+ZGRkaMiQITH7pvL50NNxkKTHH39cI0eOVDAY1KFDh/Tcc8+pqalJ7733nuG0sfp9gPD/ysrKon+eNGmSioqKNHLkSL3zzjtaunSp4WToDxYtWhT988SJEzVp0iSNGTNGtbW1mjNnjuFkiREKhXT48OGb4vug13K147B8+fLonydOnKj8/HzNmTNHLS0tGjNmTF+P2aN+/yW4nJwcDRgw4Ip3sbS3tysQCBhN1T8MGTJEd999t5qbm61HMfP1OcD5caXRo0crJycnJc+PVatWaefOnfrwww9jfn1LIBDQhQsXdObMmZj9U/V8uNpx6ElRUZEk9avzod8HKCMjQ1OmTFFNTU30se7ubtXU1Ki4uNhwMntnz55VS0uL8vPzrUcxU1hYqEAgEHN+RCIR7du376Y/P44fP67Tp0+n1PnhnNOqVau0bds27d69W4WFhTHPT5kyRQMHDow5H5qamnT06NGUOh+udxx6cvDgQUnqX+eD9bsgvou33nrL+Xw+t2XLFvePf/zDLV++3A0ZMsS1tbVZj9anfv7zn7va2lrX2trq/vrXv7qSkhKXk5PjTp06ZT1aQnV0dLgDBw64AwcOOEnu5ZdfdgcOHHD/+c9/nHPO/eY3v3FDhgxxO3bscIcOHXLz5893hYWF7quvvjKePL6udRw6Ojrc008/7err611ra6vbtWuXu++++9xdd93lzp8/bz163KxcudL5/X5XW1vrTp48Gd3OnTsX3WfFihVuxIgRbvfu3W7//v2uuLjYFRcXG04df9c7Ds3NzW7dunVu//79rrW11e3YscONHj3azZo1y3jyWEkRIOec+8Mf/uBGjBjhMjIy3LRp01xDQ4P1SH1u4cKFLj8/32VkZLjvfe97buHCha65udl6rIT78MMPnaQrtsWLFzvnLr8Ve82aNS4vL8/5fD43Z84c19TUZDt0AlzrOJw7d87NnTvXDRs2zA0cONCNHDnSLVu2LOX+kdbTf78kt3nz5ug+X331lfvpT3/q7rjjDnfrrbe6Rx55xJ08edJu6AS43nE4evSomzVrlsvOznY+n8/deeed7plnnnHhcNh28G/h1zEAAEz0++8BAQBSEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABg4v8AjVqFRqQZEfIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Take the first batch yielded by test_dataloader and display its first image in grayscale.\n",
    "test_samples = enumerate(test_dataloader)\n",
    "b_i, (sample_data, sample_targets) = next(test_samples)\n",
    "\n",
    "# sample_data[0][0] picks image 0 and index 0 of its size-1 second axis (presumably the channel axis),\n",
    "# leaving a 2-D image that imshow can render.\n",
    "plt.imshow(sample_data[0][0], cmap='gray', interpolation='none')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model prediction is : 7\n",
      "Ground truth is : 7\n"
     ]
    }
   ],
   "source": [
    "# Run the sample batch through the model in eval mode without building an autograd graph,\n",
    "# then take the highest-scoring class for each image (argmax over the class dimension).\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    sample_preds = model(sample_data).argmax(dim=1)\n",
    "print(f\"Model prediction is : {sample_preds[0]}\")\n",
    "print(f\"Ground truth is : {sample_targets[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Profiling the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ConvNet(\n",
      "  (cn1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (cn2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (dp1): Dropout2d(p=0.1, inplace=False)\n",
      "  (dp2): Dropout2d(p=0.25, inplace=False)\n",
      "  (fc1): Linear(in_features=4608, out_features=64, bias=True)\n",
      "  (fc2): Linear(in_features=64, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Print the module hierarchy: each layer of the network and its configuration.\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([500, 1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "# Shape of the sampled test batch; this same batch is the input profiled below.\n",
    "print(sample_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Profile model inference on the CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CPU time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Expand the notebook to full width so the profiler traces display properly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.container { width:100% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Inject CSS that widens the notebook container to 100% of the window,\n",
    "# making the wide profiler tables below easier to read.\n",
    "from IPython.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.profiler import profile, record_function, ProfilerActivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n",
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    }
   ],
   "source": [
    "# Profile one forward pass of the model on the CPU. record_function tags the call as model_inference\n",
    "# in the results, and record_shapes=True stores operator input shapes so events can be grouped by shape.\n",
    "# NOTE(review): this is a single call with no warm-up and no torch.no_grad() here, so timings may include\n",
    "# one-off setup and autograd bookkeeping; consider a warm-up pass and no_grad for steady-state numbers.\n",
    "with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        model(sample_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                  model_inference        10.54%       4.908ms       100.00%      46.563ms      46.563ms             1  \n",
      "                     aten::conv2d         1.98%     920.000us        51.67%      24.057ms      12.028ms             2  \n",
      "                aten::convolution         0.19%      90.000us        49.69%      23.137ms      11.569ms             2  \n",
      "               aten::_convolution         0.11%      52.000us        49.50%      23.047ms      11.524ms             2  \n",
      "         aten::mkldnn_convolution        49.11%      22.865ms        49.38%      22.995ms      11.498ms             2  \n",
      "                 aten::max_pool2d         0.04%      19.000us        20.34%       9.473ms       9.473ms             1  \n",
      "    aten::max_pool2d_with_indices        20.30%       9.454ms        20.30%       9.454ms       9.454ms             1  \n",
      "                       aten::relu         0.28%     132.000us         9.35%       4.355ms       1.452ms             3  \n",
      "                  aten::clamp_min         9.07%       4.223ms         9.07%       4.223ms       1.408ms             3  \n",
      "                     aten::linear         0.05%      24.000us         5.79%       2.696ms       1.348ms             2  \n",
      "                      aten::addmm         5.44%       2.533ms         5.62%       2.617ms       1.308ms             2  \n",
      "                    aten::flatten         2.13%     992.000us         2.18%       1.016ms       1.016ms             1  \n",
      "                      aten::empty         0.20%      93.000us         0.20%      93.000us      23.250us             4  \n",
      "                      aten::copy_         0.15%      69.000us         0.15%      69.000us      34.500us             2  \n",
      "                aten::log_softmax         0.02%       7.000us         0.12%      56.000us      56.000us             1  \n",
      "                          aten::t         0.06%      29.000us         0.12%      55.000us      27.500us             2  \n",
      "               aten::_log_softmax         0.11%      49.000us         0.11%      49.000us      49.000us             1  \n",
      "                aten::as_strided_         0.07%      31.000us         0.07%      31.000us      15.500us             2  \n",
      "                  aten::transpose         0.04%      18.000us         0.06%      26.000us      13.000us             2  \n",
      "                       aten::view         0.05%      24.000us         0.05%      24.000us      24.000us             1  \n",
      "                     aten::expand         0.03%      12.000us         0.03%      15.000us       7.500us             2  \n",
      "                 aten::as_strided         0.02%      11.000us         0.02%      11.000us       2.750us             4  \n",
      "                    aten::resize_         0.01%       6.000us         0.01%       6.000us       3.000us             2  \n",
      "            aten::feature_dropout         0.00%       2.000us         0.00%       2.000us       1.000us             2  \n",
      "               aten::resolve_conj         0.00%       0.000us         0.00%       0.000us       0.000us             4  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 46.563ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Aggregate profiler events by operator name and print them sorted by total CPU time\n",
    "# (self time plus time spent in child ops).\n",
    "print(prof.key_averages().table(sort_by=\"cpu_time_total\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  \n",
      "                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  \n",
      "                  model_inference        10.54%       4.908ms       100.00%      46.563ms      46.563ms             1                                                                                []  \n",
      "                     aten::conv2d         0.03%      13.000us        36.59%      17.038ms      17.038ms             1                         [[500, 16, 26, 26], [32, 16, 3, 3], [32], [], [], [], []]  \n",
      "                aten::convolution         0.09%      42.000us        36.56%      17.025ms      17.025ms             1                 [[500, 16, 26, 26], [32, 16, 3, 3], [32], [], [], [], [], [], []]  \n",
      "               aten::_convolution         0.06%      30.000us        36.47%      16.983ms      16.983ms             1  [[500, 16, 26, 26], [32, 16, 3, 3], [32], [], [], [], [], [], [], [], [], [], []  \n",
      "         aten::mkldnn_convolution        36.26%      16.883ms        36.41%      16.953ms      16.953ms             1                         [[500, 16, 26, 26], [32, 16, 3, 3], [32], [], [], [], []]  \n",
      "                 aten::max_pool2d         0.04%      19.000us        20.34%       9.473ms       9.473ms             1                                           [[500, 32, 24, 24], [], [], [], [], []]  \n",
      "    aten::max_pool2d_with_indices        20.30%       9.454ms        20.30%       9.454ms       9.454ms             1                                           [[500, 32, 24, 24], [], [], [], [], []]  \n",
      "                     aten::conv2d         1.95%     907.000us        15.07%       7.019ms       7.019ms             1                           [[500, 1, 28, 28], [16, 1, 3, 3], [16], [], [], [], []]  \n",
      "                aten::convolution         0.10%      48.000us        13.13%       6.112ms       6.112ms             1                   [[500, 1, 28, 28], [16, 1, 3, 3], [16], [], [], [], [], [], []]  \n",
      "               aten::_convolution         0.05%      22.000us        13.02%       6.064ms       6.064ms             1   [[500, 1, 28, 28], [16, 1, 3, 3], [16], [], [], [], [], [], [], [], [], [], []]  \n",
      "         aten::mkldnn_convolution        12.85%       5.982ms        12.98%       6.042ms       6.042ms             1                           [[500, 1, 28, 28], [16, 1, 3, 3], [16], [], [], [], []]  \n",
      "                       aten::relu         0.14%      64.000us         7.48%       3.483ms       3.483ms             1                                                               [[500, 32, 24, 24]]  \n",
      "                  aten::clamp_min         7.34%       3.419ms         7.34%       3.419ms       3.419ms             1                                                           [[500, 32, 24, 24], []]  \n",
      "                     aten::linear         0.03%      15.000us         5.54%       2.579ms       2.579ms             1                                                   [[500, 4608], [64, 4608], [64]]  \n",
      "                      aten::addmm         5.32%       2.475ms         5.45%       2.539ms       2.539ms             1                                           [[64], [500, 4608], [4608, 64], [], []]  \n",
      "                    aten::flatten         2.13%     992.000us         2.18%       1.016ms       1.016ms             1                                                       [[500, 32, 12, 12], [], []]  \n",
      "                       aten::relu         0.10%      45.000us         1.74%     812.000us     812.000us             1                                                               [[500, 16, 26, 26]]  \n",
      "                  aten::clamp_min         1.65%     767.000us         1.65%     767.000us     767.000us             1                                                           [[500, 16, 26, 26], []]  \n",
      "                     aten::linear         0.02%       9.000us         0.25%     117.000us     117.000us             1                                                       [[500, 64], [10, 64], [10]]  \n",
      "                      aten::empty         0.20%      93.000us         0.20%      93.000us      23.250us             4                                                          [[], [], [], [], [], []]  \n",
      "                      aten::addmm         0.12%      58.000us         0.17%      78.000us      78.000us             1                                               [[10], [500, 64], [64, 10], [], []]  \n",
      "                       aten::relu         0.05%      23.000us         0.13%      60.000us      60.000us             1                                                                       [[500, 64]]  \n",
      "                aten::log_softmax         0.02%       7.000us         0.12%      56.000us      56.000us             1                                                               [[500, 10], [], []]  \n",
      "                      aten::copy_         0.12%      55.000us         0.12%      55.000us      55.000us             1                                                        [[500, 64], [500, 64], []]  \n",
      "               aten::_log_softmax         0.11%      49.000us         0.11%      49.000us      49.000us             1                                                               [[500, 10], [], []]  \n",
      "                  aten::clamp_min         0.08%      37.000us         0.08%      37.000us      37.000us             1                                                                   [[500, 64], []]  \n",
      "                          aten::t         0.03%      16.000us         0.06%      30.000us      30.000us             1                                                                        [[10, 64]]  \n",
      "                          aten::t         0.03%      13.000us         0.05%      25.000us      25.000us             1                                                                      [[64, 4608]]  \n",
      "                       aten::view         0.05%      24.000us         0.05%      24.000us      24.000us             1                                                           [[500, 32, 12, 12], []]  \n",
      "                aten::as_strided_         0.04%      17.000us         0.04%      17.000us      17.000us             1                                                   [[500, 16, 26, 26], [], [], []]  \n",
      "                aten::as_strided_         0.03%      14.000us         0.03%      14.000us      14.000us             1                                                   [[500, 32, 24, 24], [], [], []]  \n",
      "                  aten::transpose         0.02%      11.000us         0.03%      14.000us      14.000us             1                                                                [[10, 64], [], []]  \n",
      "                      aten::copy_         0.03%      14.000us         0.03%      14.000us      14.000us             1                                                        [[500, 10], [500, 10], []]  \n",
      "                  aten::transpose         0.02%       7.000us         0.03%      12.000us      12.000us             1                                                              [[64, 4608], [], []]  \n",
      "                     aten::expand         0.02%       7.000us         0.02%       9.000us       9.000us             1                                                                    [[64], [], []]  \n",
      "                     aten::expand         0.01%       5.000us         0.01%       6.000us       6.000us             1                                                                    [[10], [], []]  \n",
      "                 aten::as_strided         0.01%       5.000us         0.01%       5.000us       5.000us             1                                                          [[64, 4608], [], [], []]  \n",
      "                    aten::resize_         0.01%       3.000us         0.01%       3.000us       3.000us             1                                                       [[500, 16, 26, 26], [], []]  \n",
      "                    aten::resize_         0.01%       3.000us         0.01%       3.000us       3.000us             1                                                       [[500, 32, 24, 24], [], []]  \n",
      "                 aten::as_strided         0.01%       3.000us         0.01%       3.000us       3.000us             1                                                            [[10, 64], [], [], []]  \n",
      "            aten::feature_dropout         0.00%       2.000us         0.00%       2.000us       2.000us             1                                                       [[500, 32, 12, 12], [], []]  \n",
      "                 aten::as_strided         0.00%       2.000us         0.00%       2.000us       2.000us             1                                                                [[64], [], [], []]  \n",
      "                 aten::as_strided         0.00%       1.000us         0.00%       1.000us       1.000us             1                                                                [[10], [], [], []]  \n",
      "               aten::resolve_conj         0.00%       0.000us         0.00%       0.000us       0.000us             2                                                                       [[500, 64]]  \n",
      "               aten::resolve_conj         0.00%       0.000us         0.00%       0.000us       0.000us             1                                                                     [[500, 4608]]  \n",
      "            aten::feature_dropout         0.00%       0.000us         0.00%       0.000us       0.000us             1                                                               [[500, 64], [], []]  \n",
      "               aten::resolve_conj         0.00%       0.000us         0.00%       0.000us       0.000us             1                                                                       [[500, 10]]  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  \n",
      "Self CPU time total: 46.563ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(prof.key_averages(group_by_input_shape=True).table(sort_by=\"cpu_time_total\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                       aten::relu         0.44%     172.000us        14.70%       5.740ms       1.913ms      55.91 Mb           0 b             3  \n",
      "                  aten::clamp_min        14.26%       5.568ms        14.26%       5.568ms       1.856ms      55.91 Mb      55.91 Mb             3  \n",
      "                     aten::conv2d         0.06%      22.000us        59.51%      23.232ms      11.616ms      55.79 Mb           0 b             2  \n",
      "                aten::convolution         0.20%      79.000us        59.45%      23.210ms      11.605ms      55.79 Mb           0 b             2  \n",
      "               aten::_convolution         0.12%      48.000us        59.25%      23.131ms      11.566ms      55.79 Mb           0 b             2  \n",
      "         aten::mkldnn_convolution        58.83%      22.966ms        59.13%      23.083ms      11.541ms      55.79 Mb           0 b             2  \n",
      "                      aten::empty         0.22%      85.000us         0.22%      85.000us      21.250us      55.79 Mb      55.79 Mb             4  \n",
      "                 aten::max_pool2d         0.05%      19.000us        22.08%       8.622ms       8.622ms      26.37 Mb           0 b             1  \n",
      "    aten::max_pool2d_with_indices        22.04%       8.603ms        22.04%       8.603ms       8.603ms      26.37 Mb      26.37 Mb             1  \n",
      "                     aten::linear         0.05%      18.000us         3.49%       1.364ms     682.000us     144.53 Kb           0 b             2  \n",
      "                      aten::addmm         3.17%       1.239ms         3.33%       1.300ms     650.000us     144.53 Kb     144.53 Kb             2  \n",
      "                aten::log_softmax         0.01%       5.000us         0.12%      45.000us      45.000us      19.53 Kb           0 b             1  \n",
      "               aten::_log_softmax         0.10%      40.000us         0.10%      40.000us      40.000us      19.53 Kb      19.53 Kb             1  \n",
      "                aten::as_strided_         0.07%      26.000us         0.07%      26.000us      13.000us           0 b           0 b             2  \n",
      "                    aten::resize_         0.02%       6.000us         0.02%       6.000us       3.000us           0 b           0 b             2  \n",
      "            aten::feature_dropout         0.01%       3.000us         0.01%       3.000us       1.500us           0 b           0 b             2  \n",
      "                    aten::flatten         0.04%      14.000us         0.09%      35.000us      35.000us           0 b           0 b             1  \n",
      "                       aten::view         0.05%      21.000us         0.05%      21.000us      21.000us           0 b           0 b             1  \n",
      "                          aten::t         0.06%      23.000us         0.12%      46.000us      23.000us           0 b           0 b             2  \n",
      "                  aten::transpose         0.04%      15.000us         0.06%      23.000us      11.500us           0 b           0 b             2  \n",
      "                 aten::as_strided         0.02%       8.000us         0.02%       8.000us       2.000us           0 b           0 b             4  \n",
      "                     aten::expand         0.04%      16.000us         0.04%      16.000us       8.000us           0 b           0 b             2  \n",
      "                      aten::copy_         0.12%      45.000us         0.12%      45.000us      22.500us           0 b           0 b             2  \n",
      "               aten::resolve_conj         0.00%       0.000us         0.00%       0.000us       0.000us           0 b           0 b             4  \n",
      "                         [memory]         0.00%       0.000us         0.00%       0.000us       0.000us    -138.22 Mb    -138.22 Mb            10  \n",
      "---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 39.041ms\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n",
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    }
   ],
   "source": [
    "# Profile one CPU forward pass, recording tensor allocations per operator.\n",
    "with profile(\n",
    "    activities=[ProfilerActivity.CPU],\n",
    "    profile_memory=True,\n",
    "    record_shapes=True,\n",
    ") as prof:\n",
    "    model(sample_data)\n",
    "\n",
    "# \"CPU Mem\" counts memory allocated by an op and its children; \"Self CPU Mem\" excludes children.\n",
    "print(prof.key_averages().table(sort_by=\"cpu_memory_usage\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-16 16:03:39 3073744:3073744 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                        model_inference         0.04%       1.189ms        99.94%        3.252s        3.252s       0.000us         0.00%       3.935ms       3.935ms             1  \n",
      "                                           aten::conv2d         0.00%      27.000us        95.25%        3.099s        1.550s       0.000us         0.00%       3.024ms       1.512ms             2  \n",
      "                                      aten::convolution         0.00%      94.000us        95.25%        3.099s        1.550s       0.000us         0.00%       3.024ms       1.512ms             2  \n",
      "                                     aten::_convolution         0.03%       1.124ms        95.24%        3.099s        1.550s       0.000us         0.00%       3.024ms       1.512ms             2  \n",
      "                                aten::cudnn_convolution        17.15%     557.947ms        94.37%        3.071s        1.535s     745.000us        34.75%       2.536ms       1.268ms             2  \n",
      "                                 cudaDeviceGetAttribute         0.00%       3.000us         0.00%       3.000us       0.064us       1.791ms        83.54%       1.791ms      38.106us            47  \n",
      "             cudnn_volta_scudnn_128x32_relu_small_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us     538.000us        25.09%     538.000us     538.000us             1  \n",
      "                                             aten::relu         0.00%     100.000us         0.96%      31.293ms      10.431ms       0.000us         0.00%     493.000us     164.333us             3  \n",
      "                                        aten::clamp_min         0.08%       2.590ms         0.96%      31.193ms      10.398ms     493.000us        22.99%     493.000us     164.333us             3  \n",
      "void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     493.000us        22.99%     493.000us     164.333us             3  \n",
      "                                             aten::add_         0.00%     107.000us         0.84%      27.318ms      13.659ms     488.000us        22.76%     488.000us     244.000us             2  \n",
      "void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     488.000us        22.76%     488.000us     244.000us             2  \n",
      "                                       aten::max_pool2d         0.00%       8.000us         0.23%       7.442ms       7.442ms       0.000us         0.00%     272.000us     272.000us             1  \n",
      "                          aten::max_pool2d_with_indices         0.05%       1.480ms         0.23%       7.434ms       7.434ms     272.000us        12.69%     272.000us     272.000us             1  \n",
      "void at::native::(anonymous namespace)::max_pool_for...         0.00%       0.000us         0.00%       0.000us       0.000us     272.000us        12.69%     272.000us     272.000us             1  \n",
      "void implicit_convolve_sgemm<float, float, 1024, 5, ...         0.00%       0.000us         0.00%       0.000us       0.000us     203.000us         9.47%     203.000us     203.000us             1  \n",
      "                                           aten::linear         0.00%      20.000us         2.25%      73.206ms      36.603ms       0.000us         0.00%     144.000us      72.000us             2  \n",
      "                                            aten::addmm         0.68%      22.122ms         2.25%      73.134ms      36.567ms     144.000us         6.72%     144.000us      72.000us             2  \n",
      "                                  volta_sgemm_32x128_tn         0.00%       0.000us         0.00%       0.000us       0.000us     136.000us         6.34%     136.000us     136.000us             1  \n",
      "                         volta_sgemm_32x32_sliced1x4_tn         0.00%       0.000us         0.00%       0.000us       0.000us       7.000us         0.33%       7.000us       7.000us             1  \n",
      "                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.14%       3.000us       1.500us             2  \n",
      "void cask_cudnn::computeOffsetsKernel<false, false>(...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.09%       2.000us       2.000us             1  \n",
      "                                      aten::log_softmax         0.00%      12.000us         1.22%      39.557ms      39.557ms       0.000us         0.00%       2.000us       2.000us             1  \n",
      "                                     aten::_log_softmax         0.00%      83.000us         1.22%      39.545ms      39.545ms       2.000us         0.09%       2.000us       2.000us             1  \n",
      "void (anonymous namespace)::softmax_warp_forward<flo...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.09%       2.000us       2.000us             1  \n",
      "                                  cudaStreamIsCapturing         0.00%      11.000us         0.00%      11.000us       1.833us       0.000us         0.00%       0.000us       0.000us             6  \n",
      "                                             cudaMalloc         0.05%       1.568ms         0.05%       1.568ms     142.545us       0.000us         0.00%       0.000us       0.000us            11  \n",
      "                                     cudaGetDeviceCount         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                                   cudaDriverGetVersion         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                                cudaGetDeviceProperties         0.10%       3.184ms         0.10%       3.184ms       3.184ms       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                              cudaStreamCreateWithFlags         7.81%     254.037ms         7.81%     254.037ms      15.877ms       0.000us         0.00%       0.000us       0.000us            16  \n",
      "                                        cudaMemsetAsync         0.00%      49.000us         0.00%      49.000us      16.333us       0.000us         0.00%       0.000us       0.000us             3  \n",
      "                                          cudaHostAlloc         0.04%       1.267ms         0.04%       1.267ms       1.267ms       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                               cudaHostGetDevicePointer         0.00%       2.000us         0.00%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                                               cudaFree         3.00%      97.473ms         3.00%      97.473ms      19.495ms       0.000us         0.00%       0.000us       0.000us             5  \n",
      "                                   cudaGetSymbolAddress         0.11%       3.484ms         0.11%       3.484ms       1.742ms       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                        cudaEventRecord         0.00%      14.000us         0.00%      14.000us       7.000us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                  cudaStreamGetPriority         0.00%       2.000us         0.00%       2.000us       1.000us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                       cudaDeviceGetStreamPriorityRange         0.00%       2.000us         0.00%       2.000us       1.000us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                       cudaLaunchKernel        70.49%        2.294s        70.49%        2.294s     191.164ms       0.000us         0.00%       0.000us       0.000us            12  \n",
      "                                          aten::reshape         0.00%      13.000us         0.00%      23.000us      11.500us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                             aten::view         0.00%      22.000us         0.00%      22.000us       7.333us       0.000us         0.00%       0.000us       0.000us             3  \n",
      "                                  aten::feature_dropout         0.00%       1.000us         0.00%       1.000us       0.500us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                          aten::flatten         0.00%       7.000us         0.00%      19.000us      19.000us       0.000us         0.00%       0.000us       0.000us             1  \n",
      "                                                aten::t         0.00%      27.000us         0.00%      52.000us      26.000us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                        aten::transpose         0.00%      15.000us         0.00%      25.000us      12.500us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "                                       aten::as_strided         0.00%      10.000us         0.00%      10.000us       5.000us       0.000us         0.00%       0.000us       0.000us             2  \n",
      "          cudaOccupancyMaxActiveBlocksPerMultiprocessor         0.31%      10.118ms         0.31%      10.118ms       3.373ms       0.000us         0.00%       0.000us       0.000us             3  \n",
      "                                  cudaDeviceSynchronize         0.06%       1.967ms         0.06%       1.967ms       1.967ms       0.000us         0.00%       0.000us       0.000us             1  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 3.254s\n",
      "Self CUDA time total: 2.144ms\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-16 16:03:43 3073744:3073744 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-16 16:03:43 3073744:3073744 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    }
   ],
   "source": [
    "model = model.cuda()\n",
    "sample_data = sample_data.cuda()\n",
    "\n",
    "# Warm-up pass: the first forward pass on the GPU pays one-time costs\n",
    "# (CUDA stream creation, cuDNN setup, allocator growth via cudaMalloc) that\n",
    "# would otherwise dominate the profile. Copying the result to the host is a\n",
    "# blocking copy, so the warm-up kernels finish before profiling starts.\n",
    "model(sample_data).cpu()\n",
    "\n",
    "with profile(activities=[\n",
    "        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        model(sample_data)\n",
    "\n",
    "print(prof.key_averages().table(sort_by=\"cuda_time_total\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tracing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-16 16:03:43 3073744:3073744 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n",
      "STAGE:2024-02-16 16:03:43 3073744:3073744 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-16 16:03:43 3073744:3073744 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    }
   ],
   "source": [
    "# Capture CPU and CUDA activity for one forward pass and save it as a\n",
    "# Chrome trace-event file (viewable in chrome://tracing or Perfetto).\n",
    "with profile(\n",
    "    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n",
    ") as prof:\n",
    "    model(sample_data)\n",
    "\n",
    "prof.export_chrome_trace(\"trace.json\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
